"""
Phase 4: Crisis Severity & Type
================================
Conditional analysis on crisis episodes: demographics → crisis depth,
duration, type (banking vs currency vs sovereign), and post-crisis recovery.
NFA creditor/debtor × crisis type cross-tabulation.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# Silence every warning raised through the `warnings` module for the whole run.
# NOTE(review): this also hides numerical warnings (e.g. divide-by-zero) from
# the regressions below.
warnings.filterwarnings('ignore')

# Parent of the directory that contains this file.
PROJECT_DIR = Path(__file__).resolve().parent.parent
# One level above PROJECT_DIR.
ROOT_DIR = PROJECT_DIR.parent
# Put ROOT_DIR/multilateral/src first on sys.path so `model` can be imported.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory read by main() and output directory for the markdown tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Import-time side effect: create the output directory if it does not exist.
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def run_regression(df, y_var, x_vars, label, feature_names=None):
    """Fit a PanelGLS model on the complete cases of *df* and summarise it.

    Rows with a missing value in *y_var*, any of *x_vars*, 'iso3' or 'year'
    are dropped before fitting. The coefficient table is printed as a side
    effect.

    Args:
        df: panel DataFrame containing *y_var*, *x_vars*, 'iso3' and 'year'.
        y_var: name of the dependent-variable column.
        x_vars: list of regressor column names.
        label: model label used in log lines and stored under 'model'.
        feature_names: optional display names, positionally matched to the
            fitted coefficients; falls back to *x_vars* when empty or None.

    Returns:
        A dict with 'model', 'n_obs', 'n_countries', 'r_squared', 'rho' and
        '<name>_coef' / '<name>_se' / '<name>_p' per regressor, or None when
        fewer than 30 complete rows remain or the fit raises.
    """
    needed = [y_var] + x_vars + ['iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 30:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None

    model = PanelGLS()
    response = sample[y_var].values
    design = sample[x_vars].values
    try:
        model.fit(response, design,
                  sample['iso3'].values, sample['year'].values)
    except Exception as err:
        # Best effort: a failed specification is reported and skipped so the
        # remaining models still run.
        print(f"  {label}: GLS failed ({err}), skipping")
        return None

    display_names = feature_names or x_vars
    summary = {
        'model': label,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'r_squared': model.r_squared,
        'rho': model.rho,
    }
    for pos, name in enumerate(display_names):
        summary[f'{name}_coef'] = model.beta[pos]
        summary[f'{name}_se'] = model.se[pos]
        summary[f'{name}_p'] = model.pvalues[pos]

    print(f"\n  {label} (N={model.n_obs}, R²={model.r_squared:.4f})")
    for pos, name in enumerate(display_names):
        marker = stars(model.pvalues[pos])
        print(f"    {name:30s} {model.beta[pos]:8.4f} ({model.se[pos]:.4f}) {marker}")

    return summary


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one coefficient for a table.

    Returns a tuple of two strings: the coefficient to four decimals followed
    by its significance stars, and the standard error to four decimals in
    parentheses.
    """
    marker = stars(p)
    coef_cell = f"{val:.4f}" + marker
    se_cell = "(" + f"{se:.4f}" + ")"
    return coef_cell, se_cell


def write_table(results, filename, title, note=None):
    """Write regression results as a markdown table under ``OUT_TABLES``.

    One column per model, two rows per variable (coefficient with stars, then
    standard error), followed by N, R² and country counts and a footnote.

    Args:
        results: list of result dicts. Each needs 'model', 'n_obs',
            'r_squared', 'n_countries' and, per variable, the keys
            '<name>_coef', '<name>_se' and '<name>_p'.
        filename: file name (relative to ``OUT_TABLES``) to write.
        title: heading placed at the top of the file.
        note: optional footnote describing the estimator. Defaults to the
            Panel GLS description; pass a different text for tables that were
            not estimated by Panel GLS.

    Returns:
        None. Nothing is written when *results* is empty.
    """
    if not results:
        return

    # Strip only the trailing '_coef': str.replace would also remove the
    # substring from inside a variable name and break the lookups below.
    suffix = '_coef'
    # Row order is the order of first appearance across models.
    all_vars = list(dict.fromkeys(
        key[:-len(suffix)]
        for r in results
        for key in r
        if key.endswith(suffix)
    ))

    sep = "|:---|" + "|".join("---:" for _ in results) + "|"
    lines = [
        f"# {title}\n",
        "| Variable | " + " | ".join(r['model'] for r in results) + " |",
        sep,
    ]

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                # Variable not in this model: leave both cells empty.
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Same separator row again, used as a divider above the summary rows.
    lines.append(sep)
    lines.append("| N |" + "".join(f" {r['n_obs']} |" for r in results))
    lines.append("| R² |"
                 + "".join(f" {r['r_squared']:.4f} |" for r in results))
    lines.append("| Countries |"
                 + "".join(f" {r['n_countries']} |" for r in results))

    if note is None:
        note = ("*Panel GLS with country and year fixed effects. "
                "Standard errors in parentheses.*")
    lines.append("\n" + note)
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # Explicit UTF-8: labels and rows contain non-ASCII characters (→, ²)
    # that the platform default encoding may not be able to represent.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def run_pooled_ols(df, y_var, x_vars, label):
    """Pooled OLS with an intercept for small cross-sectional samples (no FE).

    Rows with a missing value in *y_var* or any of *x_vars* are dropped.
    Standard errors are the classical ``s² (X'X)⁻¹`` ones and p-values come
    from a two-sided t-test with ``n - k`` degrees of freedom. The intercept
    is estimated but not reported. The coefficient table is printed as a side
    effect.

    Args:
        df: DataFrame with one row per observation (e.g. crisis episode).
        y_var: name of the dependent-variable column.
        x_vars: list of regressor column names.
        label: model label used in log lines and stored under 'model'.

    Returns:
        A dict with 'model', 'n_obs', 'n_countries', 'r_squared', 'rho'
        (always 0.0) and '<name>_coef' / '<name>_se' / '<name>_p' per
        regressor. 'n_countries' is the number of distinct 'iso3' values in
        the estimation sample when *df* has an 'iso3' column, otherwise the
        number of rows. Returns None when the model is skipped: fewer than 15
        complete rows, no residual degrees of freedom, collinear regressors,
        or a numerical failure.
    """
    cols = [y_var] + x_vars
    # Keep whole rows (not only the model columns) so the country identifier
    # is still available for the country count below.
    sub = df.loc[df[cols].notna().all(axis=1)]
    if len(sub) < 15:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values
    X = np.column_stack([np.ones(len(sub)), sub[x_vars].values])
    n, k = X.shape
    if n <= k:
        # No residual degrees of freedom: s² would divide by zero.
        print(f"  {label}: insufficient obs ({n}), skipping")
        return None

    try:
        beta, _, rank, _ = np.linalg.lstsq(X, y, rcond=None)
        if rank < k:
            # lstsq still returns a solution for a rank-deficient design,
            # but (X'X)⁻¹ and hence the standard errors are meaningless.
            raise np.linalg.LinAlgError("regressors are collinear")
        resid = y - X @ beta
        ss_res = np.sum(resid**2)
        var_beta = ss_res / (n - k) * np.linalg.inv(X.T @ X)
        se = np.sqrt(np.diag(var_beta))
    except (np.linalg.LinAlgError, ValueError, TypeError) as e:
        print(f"  {label}: OLS failed ({e}), skipping")
        return None

    # Imported lazily (as before) but outside the try block, so a missing
    # scipy is not misreported as an estimation failure.
    from scipy import stats as sp_stats
    t_stats = beta / se
    # Survival function avoids the cancellation in 1 - cdf for large |t|.
    pvalues = 2 * sp_stats.t.sf(np.abs(t_stats), df=n - k)

    ss_tot = np.sum((y - y.mean())**2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0.0

    # Rows are episodes, so a country can appear more than once.
    n_countries = sub['iso3'].nunique() if 'iso3' in sub.columns else n

    result = {
        'model': label,
        'n_obs': n,
        'n_countries': n_countries,
        'r_squared': r2,
        'rho': 0.0,
    }

    print(f"\n  {label} (N={n}, R²={r2:.4f}) [Pooled OLS]")
    # Position 0 is the intercept; regressors start at position 1.
    for pos, name in enumerate(x_vars, start=1):
        sig = stars(pvalues[pos])
        print(f"    {name:30s} {beta[pos]:8.4f} ({se[pos]:.4f}) {sig}")
        result[f'{name}_coef'] = beta[pos]
        result[f'{name}_se'] = se[pos]
        result[f'{name}_p'] = pvalues[pos]

    return result


def write_crosstab(ct, filename, title):
    """Write a cross-tabulation as a markdown table under ``OUT_TABLES``.

    Args:
        ct: DataFrame whose index labels become the first column and whose
            columns become the remaining table columns.
        filename: file name (relative to ``OUT_TABLES``) to write.
        title: heading placed at the top of the file.
    """
    cols = ct.columns.tolist()
    lines = [
        f"# {title}\n",
        "| | " + " | ".join(str(c) for c in cols) + " |",
        "|:---|" + "|".join("---:" for _ in cols) + "|",
    ]

    for idx, row in ct.iterrows():
        cells = "".join(f" {row[c]} |" for c in cols)
        lines.append(f"| {idx} |" + cells)

    path = OUT_TABLES / filename
    # Explicit UTF-8: titles contain non-ASCII characters (e.g. ×), and the
    # platform default encoding is not guaranteed to represent them.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"  Saved: {path}")


def _crisis_type_counts(crisis_rows, by):
    """Sum each crisis-type indicator within the groups of *by*, plus a total."""
    counts = pd.DataFrame({
        ctype.replace('_crisis', ''): crisis_rows.groupby(by)[ctype].sum()
        for ctype in ('banking_crisis', 'currency_crisis', 'sovereign_crisis')
    })
    counts['total'] = counts.sum(axis=1)
    return counts


def _cum_output_gap_after_onsets(df, horizon=5, min_obs=3):
    """Sum 'output_gap' over the *horizon* years after each banking-crisis onset.

    Returns a Series aligned with ``df.index`` that is NaN everywhere except
    at onset rows with at least *min_obs* non-missing output-gap values in
    the window ``(year, year + horizon]`` for the same country.
    """
    loss = pd.Series(np.nan, index=df.index)
    if 'output_gap' not in df.columns:
        return loss

    for idx in df.index[df['banking_crisis_onset'] == 1]:
        iso3 = df.loc[idx, 'iso3']
        yr = df.loc[idx, 'year']
        in_window = ((df['iso3'] == iso3)
                     & (df['year'] > yr) & (df['year'] <= yr + horizon))
        # Drop missing values before counting and summing: Series.sum()
        # skips NaN and returns 0.0 for an all-NaN window, which would
        # otherwise be recorded as a zero output loss.
        gaps = df.loc[in_window, 'output_gap'].dropna()
        if len(gaps) >= min_obs:
            loss.loc[idx] = gaps.sum()
    return loss


def main():
    """Run the Phase 4 analysis and write its markdown tables.

    Reads ``crises_panel.csv`` (and ``lv_episodes.csv`` when present) from
    ``DATA`` and, in order:

    1. cross-tabulates crisis types by NFA status and demographic tercile;
    2. regresses banking-crisis onset on dependency ratios (Panel GLS);
    3. regresses cumulative post-crisis output gaps on demographics
       (pooled OLS on onset rows);
    4. relates banking-crisis duration to demographics;
    5. re-runs the onset regression split at the median of ``kaopen``.

    Progress and coefficient tables are printed; result files are written
    under ``OUT_TABLES``.
    """
    print("=" * 70)
    print("PHASE 4: CRISIS SEVERITY & TYPE")
    print("=" * 70)

    df = pd.read_csv(DATA / "crises_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']

    # ── Cross-tabulations ──
    print("\n" + "=" * 50)
    print("CROSS-TABULATIONS")
    print("=" * 50)

    # NFA creditor vs debtor × crisis type
    if 'nfa_positive' in df.columns:
        status = pd.Series(
            np.where(df['nfa_positive'] == 1, 'Creditor', 'Debtor'),
            index=df.index,
        )
        # Keep the status missing where nfa_positive is missing, so rows
        # without NFA data are left out of the table instead of being
        # counted as debtors (groupby drops missing keys).
        df['nfa_status'] = status.where(df['nfa_positive'].notna())

        crisis_sub = df[df['any_crisis'] == 1].copy()
        if len(crisis_sub) > 0:
            print(f"\n  Crisis observations: {len(crisis_sub)}")

            ct = _crisis_type_counts(crisis_sub, 'nfa_status')
            print("\n  NFA Status × Crisis Type:")
            print(ct)
            write_crosstab(ct, "nfa_crisis_crosstab.md",
                           "NFA Creditor/Debtor × Crisis Type (Crisis-Year Counts)")

    # Demographic tercile × crisis type
    if 'demo_tercile' in df.columns:
        crisis_sub2 = df[df['any_crisis'] == 1].copy()
        if len(crisis_sub2) > 0:
            ct2 = _crisis_type_counts(crisis_sub2, 'demo_tercile')
            print("\n  Demo Tercile × Crisis Type:")
            print(ct2)
            write_crosstab(ct2, "demo_crisis_crosstab.md",
                           "Demographic Stage × Crisis Type (Crisis-Year Counts)")

    # ── Aging → Bank Risk Channel (Doerr et al.) ──
    print("\n" + "=" * 50)
    print("AGING → BANK RISK CHANNEL")
    print("=" * 50)

    results_aging = []

    # old_dep → banking crisis
    print("\n--- old_dep → banking_crisis_onset ---")
    r = run_regression(df, 'banking_crisis_onset',
                       ['old_dep'] + controls,
                       'M1: OADR → Banking')
    if r:
        results_aging.append(r)

    # old_dep → banking crisis, controlling for credit growth
    if 'd_gross_liab' in df.columns:
        print("\n--- old_dep → banking_crisis_onset + credit growth ---")
        r = run_regression(df, 'banking_crisis_onset',
                           ['old_dep', 'd_gross_liab'] + controls,
                           'M2: + Credit Growth')
        if r:
            results_aging.append(r)

    # youth_dep → banking crisis (for comparison)
    print("\n--- youth_dep → banking_crisis_onset ---")
    r = run_regression(df, 'banking_crisis_onset',
                       ['youth_dep'] + controls,
                       'M3: Youth → Banking')
    if r:
        results_aging.append(r)

    if results_aging:
        write_table(results_aging, "aging_bank_risk.md",
                    "Aging and Bank Risk Channel")

    # ── Post-Crisis Recovery & Duration ──
    print("\n" + "=" * 50)
    print("POST-CRISIS RECOVERY")
    print("=" * 50)

    results_recovery = []

    df = df.sort_values(['iso3', 'year'])

    # For each banking crisis onset, sum the output gap over the next 5 years.
    onset_mask = df['banking_crisis_onset'] == 1
    if onset_mask.any():
        df['cum_output_loss_5yr'] = _cum_output_gap_after_onsets(df)

        # Z → cumulative output loss (crisis onset subsample)
        # Use pooled OLS — too few episodes for panel FE
        onset_df = df[onset_mask].copy()
        n_with_data = onset_df['cum_output_loss_5yr'].notna().sum()
        print(f"\n  Banking crisis onsets with output data: {n_with_data}")

        if n_with_data >= 20:
            recovery_specs = [
                (['Z_1', 'Z_2', 'Z_3'], 'M4: Z → Output Loss'),
                (['old_dep', 'youth_dep'], 'M5: Age → Output Loss'),
                (['Z_1', 'Z_2', 'Z_3', 'rgdp_growth', 'kaopen'],
                 'M6: Z + Controls'),
            ]
            for x_vars, label in recovery_specs:
                r = run_pooled_ols(onset_df, 'cum_output_loss_5yr',
                                   x_vars, label)
                if r:
                    results_recovery.append(r)

    # ── Crisis duration ──
    print("\n--- Crisis Duration ---")
    episodes_path = DATA / "lv_episodes.csv"
    episodes = pd.read_csv(episodes_path) if episodes_path.exists() else None

    if episodes is not None and len(episodes) > 0:
        banking_ep = episodes[episodes['crisis_type'] == 'banking'].copy()
        # Inclusive year count.
        banking_ep['duration'] = (banking_ep['end_year']
                                  - banking_ep['start_year'] + 1)

        # Merge demographics at onset year
        banking_ep = banking_ep.merge(
            df[['iso3', 'year', 'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'nfa_gdp_lag', 'kaopen', 'rgdp_growth']],
            left_on=['iso3', 'start_year'],
            right_on=['iso3', 'year'],
            how='left'
        )

        # Apply the size threshold to the sample that is actually analysed
        # (complete in duration, Z_1 and old_dep).
        sub = banking_ep.dropna(subset=['duration', 'Z_1', 'old_dep'])
        if len(sub) >= 20:
            from scipy import stats as sp_stats
            corr_z1, p_z1 = sp_stats.pearsonr(sub['Z_1'], sub['duration'])
            corr_old, p_old = sp_stats.pearsonr(sub['old_dep'], sub['duration'])
            print(f"  Correlation(Z₁, duration): {corr_z1:.3f} (p={p_z1:.3f})")
            print(f"  Correlation(old_dep, duration): {corr_old:.3f} (p={p_old:.3f})")

            # Pooled OLS of duration on each regressor set
            r_dur_z = run_pooled_ols(sub, 'duration',
                                     ['Z_1', 'Z_2', 'Z_3'], 'Duration ~ Z')
            r_dur_age = run_pooled_ols(sub, 'duration',
                                       ['old_dep', 'youth_dep'],
                                       'Duration ~ Age')

            # Save duration stats
            dur_lines = [
                "# Crisis Duration and Demographics\n",
                f"Banking crisis episodes matched: {len(sub)}",
                f"Mean duration: {sub['duration'].mean():.1f} years",
                f"Median duration: {sub['duration'].median():.1f} years",
                f"\nCorrelation(Z₁, duration): {corr_z1:.3f} (p={p_z1:.3f})",
                f"Correlation(old_dep, duration): {corr_old:.3f} (p={p_old:.3f})",
            ]

            # Add OLS results
            for r_dur in (r_dur_z, r_dur_age):
                if not r_dur:
                    continue
                dur_lines.append(
                    f"\n### {r_dur['model']} (N={r_dur['n_obs']}, "
                    f"R²={r_dur['r_squared']:.4f})"
                )
                for k, v in r_dur.items():
                    if not k.endswith('_coef'):
                        continue
                    vname = k[:-len('_coef')]
                    se_val = r_dur[vname + '_se']
                    p_val = r_dur[vname + '_p']
                    dur_lines.append(
                        f"  {vname}: {v:.4f} (se={se_val:.4f}, p={p_val:.3f})"
                    )

            duration_path = OUT_TABLES / "crisis_duration.md"
            # Explicit UTF-8: the text contains Z₁ and R², which the platform
            # default encoding may not be able to represent.
            duration_path.write_text('\n'.join(dur_lines), encoding='utf-8')
            print(f"  Saved: {duration_path}")

    if results_recovery:
        # NOTE(review): write_table's footnote describes Panel GLS with fixed
        # effects, but these models are pooled OLS without fixed effects.
        write_table(results_recovery, "post_crisis_recovery.md",
                    "Post-Crisis Cumulative Output Loss")

    # ── Income Group Heterogeneity ──
    print("\n" + "=" * 50)
    print("INCOME GROUP HETEROGENEITY")
    print("=" * 50)

    results_income = []

    # The split is on capital-account openness (kaopen) at its sample
    # median, not on an income classification.
    # NOTE(review): the section banner says "income group"; confirm which
    # split is intended.
    kaopen_median = df['kaopen'].median()
    # Both comparisons are False for missing kaopen, so rows without an
    # openness value belong to neither group rather than to "Low Openness".
    openness_groups = [
        ('High Openness', df['kaopen'] > kaopen_median),
        ('Low Openness', df['kaopen'] <= kaopen_median),
    ]

    for group_label, mask in openness_groups:
        sub = df[mask].copy()
        if len(sub) < 100:
            continue

        print(f"\n--- {group_label} (N={len(sub)}) ---")
        r = run_regression(sub, 'banking_crisis_onset',
                           ['Z_1', 'Z_2', 'Z_3',
                            'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth'],
                           group_label)
        if r:
            results_income.append(r)

    if results_income:
        write_table(results_income, "crisis_by_openness.md",
                    "Crisis Prediction by Capital Account Openness")

    print("\n" + "=" * 70)
    print("PHASE 4 COMPLETE")
    print("=" * 70)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
